/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

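/**
 * Example: half-precision GEMM on Turing Tensor Cores with a fused
 * LinearCombinationRelu (bias + ReLU) epilogue. The result is compared
 * against a device reference GEMM followed by a host-side bias + ReLU pass.
 */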

#include <algorithm>
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"

// Element types used throughout the example.
//
// A and B are stored in half precision; accumulation, the epilogue arithmetic
// and the output matrix D all use single precision.
using ElementAccumulator = float;                   // accumulator
using ElementComputeEpilogue = ElementAccumulator;  // epilogue operations
using ElementInputA = cutlass::half_t;              // elements of input matrix A
using ElementInputB = cutlass::half_t;              // elements of input matrix B
using ElementOutput = float;                        // elements of output matrix D

// Matrix layouts: A and B are column-major, the output (C and D) is row-major.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Operator class: run the math on Tensor Cores rather than on the regular
// SIMT cores of the SM.
using MMAOp = cutlass::arch::OpClassTensorOp;

// Target CUDA SM architecture (SM75).
using SmArch = cutlass::arch::Sm75;

// Tile computed by one threadblock: M = 128, N = 128, K = 32.
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;

// Tile computed by one warp: M = 64, N = 64, K = 32.
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;

// Tile computed by a single MMA operation: M = 16, N = 8, K = 8.
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;

// How threadblocks are scheduled on the GPU; the identity swizzle with its
// default template arguments is used here.
using SwizzleThreadBlock =
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Epilogue: LinearCombinationRelu, which is approximately
//
//    d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij )
//
// Template arguments, in order:
//   1. ElementOutput           - data type of the output matrix.
//   2. 128 / sizeof_bits<...>  - number of elements per vectorized memory
//                                access (one 128-bit access). ElementOutput is
//                                float in this file, so this is 4 elements.
//                                It is also the vector width of the math
//                                instructions in the epilogue.
//   3. ElementAccumulator      - data type of the accumulator.
//   4. ElementComputeEpilogue  - data type of alpha/beta in the linear
//                                combination.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementComputeEpilogue>;

// Number of pipeline stages used by the kernel.
constexpr int NumStages = 2;

// Device-level GEMM assembled from the element types, layouts, operator class,
// architecture, tile shapes, epilogue, swizzle and stage count chosen above.
using Gemm = cutlass::gemm::device::Gemm<
        ElementInputA, LayoutInputA,  // A operand
        ElementInputB, LayoutInputB,  // B operand
        ElementOutput, LayoutOutput,  // C / D operands
        ElementAccumulator,           // accumulator type
        MMAOp, SmArch,                // operator class and architecture
        ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,  // tile shapes
        EpilogueOp, SwizzleThreadBlock, NumStages>;

int run() {
    const int length_m = 5120;
    const int length_n = 4096;
    const int length_k = 4096;

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
            problem_size.mk());  // <- Create matrix A with dimensions M x K
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
            problem_size.kn());  // <- Create matrix B with dimensions K x N

    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
            {problem_size.m(), 1});  // <- Create matrix C with dimensions M x 1

    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
            problem_size.mn());  // <- Create matrix D with dimensions M x N
                                 // used to store output from CUTLASS kernel
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
            problem_size.mn());  // <- Create matrix D with dimensions M x N
                                 // used to store output from reference kernel

    // Fill input and output matrices on host using CUTLASS helper functions
    cutlass::reference::host::TensorFillRandomUniform(
            tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4),
            0);  // <- Fill matrix A on host with uniform-distribution random
                 // data
    cutlass::reference::host::TensorFillRandomUniform(
            tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4),
            0);  // <- Fill matrix B on host with uniform-distribution random
                 // data
    cutlass::reference::host::TensorFillRandomUniform(
            tensor_c_bias.host_view(), 1, ElementOutput(4), ElementOutput(-4),
            0);  // <- Fill matrix C on host with uniform-distribution random
                 // data
    cutlass::reference::host::TensorFill(
            tensor_d.host_view());  // <- fill matrix D on host with zeros
    cutlass::reference::host::TensorFill(
            tensor_ref_d.host_view());  // <- fill matrix D for reference on
                                        // host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c_bias.sync_device();
    tensor_d.sync_device();
    tensor_ref_d.sync_device();

    // Initialize alpha and beta for dot product computation
    ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
    ElementComputeEpilogue beta = ElementComputeEpilogue(0);

    // Split K dimension into 1 partitions
    int split_k_slices = 1;

    // Create a tuple of gemm kernel arguments. This is later passed as
    // arguments to launch instantiated CUTLASS kernel
    typename Gemm::Arguments arguments{
            problem_size,           // <- problem size of matrix multiplication
            tensor_a.device_ref(),  // <- reference to matrix A on device
            tensor_b.device_ref(),  // <- reference to matrix B on device

            {tensor_c_bias.device_data(),
             0},  // <- the C matrix is treated as the bias vector. We can
                  // enable the GEMM
                  //    to project away the N dimension by setting the stride to
                  //    zero.

            tensor_d.device_ref(),  // <- reference to matrix D on device
            {alpha, beta},          // <- tuple of alpha and beta
            split_k_slices};        // <- k-dimension split factor

    // Using the arguments, query for extra workspace required for matrix
    // multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Instantiate CUTLASS kernel depending on templates
    Gemm gemm_op;

    // Initialize CUTLASS kernel with arguments and workspace pointer
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);

    // Launch initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);

    //
    // Create instantiation for device reference gemm kernel
    //

    cutlass::reference::device::Gemm<ElementInputA, LayoutInputA, ElementInputB,
                                     LayoutInputB, ElementOutput, LayoutOutput,
                                     ElementComputeEpilogue,
                                     ElementComputeEpilogue>
            gemm_device_reference;

    // Launch device reference to compute strictly the product A * B
    gemm_device_reference(problem_size, alpha, tensor_a.device_ref(),
                          tensor_b.device_ref(), 0, tensor_ref_d.device_ref());

    // Wait for kernels to finish
    cudaDeviceSynchronize();

    // Copy output data from CUTLASS and reference kernel to host for comparison
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    // Compute bias + relu in host code
    for (int i = 0; i < problem_size.m(); ++i) {
        for (int j = 0; j < problem_size.n(); ++j) {
            tensor_ref_d.at({i, j}) =
                    std::max(ElementOutput(0),
                             ElementOutput(tensor_ref_d.at({i, j}) +
                                           beta * tensor_c_bias.at({i, 0})));
        }
    }

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    std::cout << (cutlass::reference::host::TensorEquals(
                          tensor_d.host_view(), tensor_ref_d.host_view())
                          ? "Passed"
                          : "Failed")
              << std::endl;

    CUTLASS_CHECK(status);
    return 0;
}

int main() {
    bool notSupported = false;

    // Turing Tensor Core operations exposed with mma.sync are first available
    // in CUDA 10.2.
    //
    // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
    if (!(__CUDACC_VER_MAJOR__ > 10 ||
          (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
        std::cerr << "Turing Tensor Core operations must be compiled with CUDA "
                     "10.2 Toolkit or later."
                  << std::endl;
        notSupported = true;
    }

    cudaDeviceProp props;

    cudaError_t error = cudaGetDeviceProperties(&props, 0);
    if (error != cudaSuccess) {
        std::cerr << "cudaGetDeviceProperties() returned an error: "
                  << cudaGetErrorString(error) << std::endl;
        return -1;
    }

    if (!(props.major * 10 + props.minor >= 75)) {
        std::cerr << "Turing Tensor Ops must be run on a machine with compute "
                     "capability at least 75."
                  << std::endl;
        notSupported = true;
    }

    if (notSupported) {
        // Returning zero so this test passes on older Toolkits. Its actions are
        // no-op.
        return 0;
    }

    return run();
}
